import torch as T
import torch.nn as nn

from custom_models.neural_causal.scm.nn.normalizing_flow import NF


class Continuous(nn.Module):
    """Continuous mechanism backed by a conditional normalizing flow (``NF``).

    The flow is conditioned on a context built by concatenating, along the last
    axis, the parent tensors (keys of ``v_size`` in sorted order) followed by
    the noise tensors (keys of ``u_size`` in sorted order).

    Args:
        v_size: Mapping from parent-variable name to its feature width; every
            ``pa[k]`` given to :meth:`forward` must satisfy
            ``pa[k].shape[-1] == v_size[k]``.
        u_size: Mapping from noise-variable name to its feature width, checked
            the same way against ``u[k]``.
        o_size: Passed to ``NF`` as its first positional argument — presumably
            the dimensionality of the modelled variable (TODO confirm in NF).
        K: Forwarded to ``NF`` as its ``K`` keyword (default 6, the value that
            used to be hard-coded). NOTE(review): the meaning of ``K`` is
            defined by ``NF``, not in this file.
    """

    def __init__(self, v_size, u_size, o_size, K=6):
        super().__init__()
        self.v = sorted(v_size)
        self.u = sorted(u_size)
        self.v_size = v_size
        self.u_size = u_size
        self.o_size = o_size
        # The flow's context is the concatenation of all parent and noise
        # features, so its width is the sum of their declared sizes.
        context_size = sum(v_size.values()) + sum(u_size.values())
        self.nf = NF(o_size, context_size, K=K)  # TODO how to choose K

    def forward(self, pa, u, v=None, n=None):
        """Score ``v`` under the flow (if given) or draw samples from it.

        Args:
            pa: Mapping from parent name to tensor; must contain every key of
                ``self.v_size`` with a matching last dimension.
            u: Mapping from noise name to tensor; must contain every key of
                ``self.u_size`` with a matching last dimension.
            v: Values to score. Mutually exclusive with ``n``.
            n: Number of samples to request when the mechanism has neither
                parent nor noise inputs (default 1). When such inputs exist,
                the sample count is derived from the context shape instead and
                ``n`` is not used.

        Returns:
            With ``v``: ``-self.nf(v, context).sum(dim=-1)``.
            Without ``v``: the first element returned by ``self.nf.sample``,
            with a leading axis added when its first dimension is not 1.

        Raises:
            AssertionError: if both ``v`` and ``n`` are set, or an input's last
                dimension does not match its declared size.
        """
        assert n is None or v is None, "v and n may not both be set"
        if n is None:
            n = 1

        self._check_sizes(pa, u)

        if v is not None:  # compute log P(v | pa_V, u_V)
            context = T.cat([pa[k] for k in self.v] + [u[k] for k in self.u], dim=-1)
            # minus because this will be negated again
            # (here, usually a loglikelihood is returned, which should be maximized)
            return -self.nf(v, context).sum(dim=-1)

        # sample from P(V)
        has_inputs = bool(self.v or self.u)
        if has_inputs:
            context = self._sampling_context(pa, u)
        else:
            context = T.empty(n, 0, device=next(self.parameters()).device)
        context, num_samples = self._sample_layout(context, n, has_inputs)
        sample = self.nf.sample(num_samples=num_samples, context=context)[0]
        if sample.shape[0] != 1:
            sample = T.unsqueeze(sample, 0)
        return sample

    def _check_sizes(self, pa, u):
        """Assert each declared parent/noise input has its declared last dimension."""
        for k in self.v_size:
            assert pa[k].shape[-1] == self.v_size[k], (k, pa[k].shape[-1], self.v_size[k])
        for k in self.u_size:
            assert u[k].shape[-1] == self.u_size[k], (k, u[k].shape[-1], self.u_size[k])

    def _sampling_context(self, pa, u):
        """Concatenate parent and noise inputs into the context used for sampling."""
        # NOTE(review): parents whose name contains "Instr_", and every noise
        # input, contribute only their first entry along dim 0 here (estimation
        # uses the full tensors); the reason is not visible in this file.
        v_context = [pa[k][0] if "Instr_" in k else pa[k] for k in self.v]
        if v_context:
            # reshape noise to the first parent's shape, presumably so that
            # T.cat can join them along the last axis
            u_context = [u[k][0].reshape_as(v_context[0]) for k in self.u]
        else:
            u_context = [u[k][0] for k in self.u]
        return T.cat(v_context + u_context, dim=-1)

    @staticmethod
    def _sample_layout(context, n, has_inputs):
        """Return ``(context, num_samples)`` in the form passed to ``NF.sample``.

        Without parent/noise inputs the ``(n, 0)`` context is kept and ``n``
        samples are requested. Otherwise a context with more than one row is
        kept and ``num_samples`` is its row count; a single-row ``(1, m)``
        context is squeezed to ``(m,)`` and ``num_samples`` is ``m``.
        """
        if not has_inputs:
            # The shape heuristic below would read an empty (1, 0) context
            # (n == 1) as "(1, m) with m == 0" and request zero samples.
            return context, n
        if context.shape[0] > 1:
            return context, context.shape[0]
        # NOTE(review): a genuine single-row batch with m features cannot be
        # told apart from m entries laid out as (1, m); the latter is assumed.
        return T.squeeze(context, 0), context.shape[1]


if __name__ == "__main__":
    # Demo: two parents (widths 2 and 1), two noise inputs (widths 1 and 2),
    # and a 3-dimensional output.
    model = Continuous({"v1": 2, "v2": 1}, {"u1": 1, "u2": 2}, 3)
    print(model)

    parents = {
        "v1": T.tensor([[1, 2], [3, 4.0]]),
        "v2": T.tensor([[5], [6.0]]),
    }
    noise = {
        "u1": T.tensor([[7.0], [8]]),
        "u2": T.tensor([[9, 10], [11, 12.0]]),
    }
    values = T.tensor([[1, 2, 3], [4, 5, 6]]).float()

    # scoring mode, then sampling mode
    print(model(parents, noise, values))
    print(model(parents, noise, n=1))

    import pandas as pd

    # NOTE(review): because parents/noise are given, forward() derives the
    # sample count from the context shape, so n=10000 does not size this draw.
    draws = model(parents, noise, n=10000)
    print(pd.DataFrame(draws.detach().numpy()))
